import os
from pathlib import Path

import time
import numpy as np
from skimage import io
from skimage.filters import median
from skimage.morphology import square

import img_util.io
import img_util.util


def roi_crop_img(img_array, roi):
    """
    Take a 2D grayscale image and a list of rectangular ROIs, and crop the image using the ROIs.

    Parameters
    ----------
    img_array: ndarray
        a 2D image as numpy array
    roi: list (of dict)
        a list of ROI values, same format as ROI values generated from ImageJ ROI manager.
        Each dict must provide the keys 'type', 'left', 'top', 'width' and 'height'.

    Returns
    -------
    imgs_roi: list
        a list of 2D image(s) as numpy array, one per ROI and in the same order as `roi`.
        Each crop is produced by basic slicing, so it is a view into `img_array`, not a copy.

    Raises
    ------
    ValueError
        If `img_array` is not 2D, if an ROI is not of type 'rectangle', or if an ROI
        extends beyond the image boundary on any side (including a negative top/left).
    """

    # check that input image is 2D array (i.e., grayscale)
    if img_array.ndim != 2:
        raise ValueError("Error: Input image must be 2D array. Instead it has %d dimension"
                         % img_array.ndim)

    n_rows, n_cols = img_array.shape
    imgs_roi = []  # list container for output images

    for i, roi_i in enumerate(roi):

        # make sure ROI is of "rectangle" type
        if roi_i['type'] != 'rectangle':
            raise ValueError(
                "Error: ROI type must be 'rectangle'. Instead it is '%s' for ROI index: %i."
                % (roi_i['type'], i))

        # retrieve the top-left corner and size of the rectangular ROI
        l, t, w, h = roi_i['left'], roi_i['top'], roi_i['width'], roi_i['height']

        # make sure ROI is within image boundary on every side. Negative top/left must be
        # rejected explicitly: numpy treats negative slice starts as offsets from the end of
        # the axis, which would silently yield a wrong (usually empty) crop.
        if t < 0 or l < 0 or t + h > n_rows or l + w > n_cols:
            raise ValueError("Error: boundary for ROI (index: %s) is outside image boundary"
                             % i)

        img_slc = img_array[t:t + h, l:l + w]  # slice out image using ROI
        imgs_roi.append(img_slc)

    return imgs_roi


def batch_roi_crop_img(img_path, roi_path, img_save_path, img_crop_suffix,
                       file_in_ext, file_out_ext, file_in_sep):
    """
    Take a path containing images, and a path containing ROIs (from ImageJ), and output cropped images based on ROIs.

    Input image filenames are expected to look like ``<subject>_<slide>_<channel><file_in_ext>`` and ROI filenames
    like ``<subject>_<slide>[_...]<ext>``. Each image is cropped with the ROIs of every ROI file whose subject and
    slide match it, and each crop is saved as
    ``<img_save_path>/<subject>/<subject><sep><slide><sep><suffix><sep><channel><file_out_ext>``.

    Parameters
    ----------
    img_path: str or path-like
        directory path to input images (searched recursively)
    roi_path: str or path-like
        directory path to ROIs (searched recursively)
    img_save_path: str or path-like
        directory path where cropped images will be saved; created (including parents) if missing
    img_crop_suffix: list of str
        cropped images' filename suffixes, one per ROI in each ROI file, in ROI order
    file_in_ext: str
        input images' file extension including the dot (e.g., ".tif"); compared case-sensitively
    file_out_ext: str
        cropped images' file extension (e.g., ".png")
    file_in_sep: str
        separator used to join subject, slide, suffix and channel in the progress message and in output filenames.
        NOTE(review): input and ROI filenames are always split on "_", regardless of this value.

    Returns
    -------
    None
        Images are saved in img_save_path

    Raises
    ------
    ValueError
        If an input filename does not split into exactly three "_"-separated parts, if the number of ROIs in a
        matching ROI file differs from len(img_crop_suffix), if no ROI file matches an input image, or if
        roi_crop_img rejects an ROI.
    """

    sep = file_in_sep

    save_root = Path(img_save_path)
    save_root.mkdir(parents=True, exist_ok=True)

    # Index the ROI directory once instead of re-walking it for every image. Entries keep os.walk order so that,
    # when several ROI files match the same image, later ones overwrite the crops of earlier ones (as before).
    roi_index = []  # (subject, slide, path to ROI file)
    for dirpath_r, _, roi_files in os.walk(roi_path):
        for roi_file in roi_files:
            name_parts_r = os.path.splitext(roi_file)[0].split("_")
            if len(name_parts_r) < 2:
                # a name without "_" carries no subject/slide pair, so it can never match an image
                continue
            roi_index.append((name_parts_r[0], name_parts_r[1], os.path.join(dirpath_r, roi_file)))

    for dirpath_i, _, img_files in os.walk(img_path):
        for img_file in img_files:
            fname_i, file_ext_i = os.path.splitext(img_file)
            if file_ext_i != file_in_ext:
                continue  # not an input image

            name_parts_i = fname_i.split("_")
            if len(name_parts_i) != 3:
                raise ValueError("Input filename must be '<subject>_<slide>_<channel>', got: " + img_file)
            sbj_i, slide_i, channel_i = name_parts_i
            print("Processing: " + sbj_i + sep + slide_i + sep + channel_i)

            img_array = io.imread(os.path.join(dirpath_i, img_file))  # read image as array

            img_sv_path = save_root / sbj_i  # directory for saving image
            img_sv_path.mkdir(exist_ok=True)  # create directory if needed

            found_roi = False  # ROI flag
            for sbj_r, slide_r, roi_file_path in roi_index:
                if sbj_i != sbj_r or slide_i != slide_r:
                    continue

                roi_val = img_util.util.get_roi(roi_file_path)[0]
                img_crop = roi_crop_img(img_array, roi_val)

                if len(img_crop) != len(img_crop_suffix):
                    raise ValueError("Number of ROIs do not match the number of suffixes")

                for img_i, img_sfx in zip(img_crop, img_crop_suffix):
                    img_fname = sbj_i + sep + slide_i + sep + img_sfx + sep + channel_i + file_out_ext
                    io.imsave(str(img_sv_path / img_fname), img_i)
                found_roi = True

            if not found_roi:
                raise ValueError("ROI file not found for " + img_file)


def get_thresh_masks_w_bkgnd(mask_path, img_basedir, excl_roi_dir, bkgnd_roi_channel, img_ext, bkgnd_rm_holes_area,
                             bkgnd_open_close_size, bkgnd_colour, bkgnd_fill_alpha, cell_diameter, mask_threshold_dict,
                             rm_small_bg_area=None):
    """
    Wrapper function: given input mask file, 1) get background ROI mask; 2) threshold input masks using original image's
    intensity; 3) add mask overlay to original image.

    Parameters
    ----------
    mask_path: str
        Path to mask file (loaded with ``np.load``).
    img_basedir: str
        Directory to search for the original and background images.
    excl_roi_dir: str
        Directory to search for ImageJ ROI file. Extracted ROI will be excluded from background ROI.
    bkgnd_roi_channel: str or None
        Channel to use for background (e.g., 'AF488'). Set to 'None' if not using background ROI, in which case every
        pixel of the original image is treated as background ROI.
    img_ext: str
        Extension of image files.
    bkgnd_rm_holes_area: int or float
        Fill holes smaller than this size (in pixels).
    bkgnd_open_close_size: int
        Size of structure element for performing morphological opening then closing (for smoothing edges).
    bkgnd_colour: str
        A string of named colour for matplotlib's "colors" module.
    bkgnd_fill_alpha: int or float
        A scalar for scaling mask fill transparency (0=transparent, 1=solid colour).
    cell_diameter: float
        Minimum diameter of cells in pixels.
    mask_threshold_dict: dict
        Dictionary containing mask threshold parameters.
    rm_small_bg_area: int or float or None
        Remove isolated background area smaller than this size (in pixels). Default=None.

    Returns
    -------
    bkgnd_mask: ndarray
        Boolean background ROI mask; pixels inside exclusion ROIs (if any) are set to False.
    masks_thresh: list
        Result of ``img_util.util.threshold_masks``: a 3-item list of 1) 2D array of thresholded masks; 2) list of
        average intensity for each mask; 3) list of number of pixels for each mask.
    img_w_mask: ndarray
        RGB image of mask outlines overlain on the original image (with the background ROI fill applied).
    """

    orig_img_path = img_util.io.get_orig_img_path(mask_path, img_basedir, img_ext)
    orig_img = io.imread(orig_img_path)

    # 1) Get background ROI mask
    if bkgnd_roi_channel is not None:
        bkgnd_img_path = img_util.io.get_bkgnd_img_path(mask_path, img_basedir, bkgnd_roi_channel, img_ext)
        # apply a 3x3 median filter to denoise the background image before thresholding
        # NOTE(review): med_filter_size=3 is also passed to threshold_background_mask below; confirm the image is not
        # meant to be median-filtered only once.
        bkgnd_img = median(io.imread(bkgnd_img_path), square(3))
        bkgnd_mask = img_util.util.threshold_background_mask(bkgnd_img, bkgnd_rm_holes_area,
                                                             open_close_size=bkgnd_open_close_size, med_filter_size=3,
                                                             rm_small_bg_area=rm_small_bg_area)
    else:
        # no background channel: all-True mask with the same shape as the original image
        bkgnd_mask = np.ones(orig_img.shape, dtype='bool')

    # 1.1) Exclude ROI mask from background mask if necessary
    roi_excl_path = img_util.io.get_roi_excl_path(mask_path, excl_roi_dir)
    if len(roi_excl_path) != 0:
        roi_excl = img_util.util.rois_to_masks(roi_excl_path, bkgnd_mask.shape, fill_roi=True)
        bkgnd_mask[roi_excl] = False  # modifies bkgnd_mask in place

    # 2) Get thresholded mask and mask-overlaid image
    # 2.1) Add background ROI mask to input image
    img_bg_roi = img_util.util.add_bg_mask_to_img(orig_img, bkgnd_mask, bkgnd_colour, bkgnd_fill_alpha)

    # 2.2) Threshold masks based on original image's intensity
    # NOTE(review): allow_pickle=True lets a crafted .npy file execute arbitrary code when loaded; only use trusted
    # mask files, and consider allow_pickle=False if the masks are plain numeric arrays.
    masks = np.load(str(mask_path), allow_pickle=True)
    # Remove masks not in background ROI (assumes bkgnd_mask is boolean; `~` on an integer array is bitwise NOT)
    masks[~bkgnd_mask] = 0
    print("Thresholding masks for " + os.path.basename(mask_path))
    # perf_counter is monotonic and high-resolution, unlike time.time(), so the measured duration is unaffected by
    # system clock adjustments during the run.
    t_start = time.perf_counter()
    masks_thresh = img_util.util.threshold_masks(orig_img, masks, cell_diameter, mask_threshold_dict)
    print('    Thresholding took ' + str(round(time.perf_counter() - t_start, 2)) + ' sec')

    # 2.3) Overlay mask outlines to original image
    mask_outlines = img_util.util.masks_to_outlines(masks_thresh[0])
    outlines_rgb = img_util.util.binary_masks_to_colour(mask_outlines)
    img_w_mask = img_util.util.add_colour_mask_to_img(img_bg_roi, outlines_rgb)

    return bkgnd_mask, masks_thresh, img_w_mask
